LongNet – Scaling Transformers to 1,000,000,000 Tokens

Large Language Models
Author

Maxime Labonne

Published

April 7, 2024

Tip

This paper introduces the dilated attention mechanism, another sparse attention scheme that approximates full attention at a much lower cost.

📝 Paper: https://arxiv.org/pdf/2307.02486.pdf

The authors claim their technique can scale up to 1 billion tokens. However, they only show results up to 32K tokens and assume you can use 10,000 GPUs.

Dilated Attention

N is the sequence length and d is the hidden dimension.

Vanilla attention maps a query and a set of keys and values to an output:

O = \text{softmax}(QK^\top)V

Sparse attention restricts each query's access to a subset of keys and values to reduce the complexity:

O = \text{softmax}(QK^\top \odot \mathbb{1}_S)V

Dilated attention splits the input (Q, K, V) into \frac{N}{w} segments \{(\widetilde{Q}_i, \widetilde{K}_i, \widetilde{V}_i)\}^{\frac{N}{w}} of equal length w. Each segment is then sparsified along the sequence dimension by selecting rows with an interval r (the dilation rate). The computation can be written as:

\begin{align*} \widetilde{Q}_i &= [Q_{iw}, Q_{iw+r}, Q_{iw+2r}, \ldots, Q_{(i+1)w-1}] \\ \widetilde{K}_i &= [K_{iw}, K_{iw+r}, K_{iw+2r}, \ldots, K_{(i+1)w-1}] \\ \widetilde{V}_i &= [V_{iw}, V_{iw+r}, V_{iw+2r}, \ldots, V_{(i+1)w-1}] \end{align*}

The sparsified segments \{(\widetilde{Q}_i, \widetilde{K}_i, \widetilde{V}_i)\}^{\frac{N}{w}} are fed into the attention in parallel, after which they are scattered and concatenated into the output O:

\begin{align*} \widetilde{O}_i &= \text{softmax}(\widetilde{Q}_i \widetilde{K}^\top_{i}) \widetilde{V}_i \\ \hat{O}_i &= \{\widetilde{O}_{i,j} \mid j \bmod r = 0;\; 0 \mid j \bmod r \neq 0\} \\ O &= [\hat{O}_0, \hat{O}_1, \dots, \hat{O}_{\frac{N}{w} - 1}] \end{align*}

The idea is quite visual: you get vanilla attention between tokens within the same segment, and a sparser mechanism between tokens from different segments, obtained in practice by mixing several (w, r) configurations where larger segments are paired with larger dilation rates.
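To make the mechanism concrete, here is a minimal PyTorch sketch of a single (w, r) configuration, following the equations above: single head, no causal mask, and no 1/\sqrt{d} scaling. It is an illustration of the technique, not the paper's implementation (the function name and toy shapes are mine).

```python
import torch
import torch.nn.functional as F

def dilated_attention(Q, K, V, w, r):
    """Single-head dilated attention. Q, K, V: (N, d) tensors."""
    N, d = Q.shape
    O = torch.zeros_like(Q)
    for i in range(N // w):
        # Sparsification: keep rows iw, iw+r, iw+2r, ... inside segment i.
        idx = torch.arange(i * w, (i + 1) * w, r)
        Qi, Ki, Vi = Q[idx], K[idx], V[idx]
        # Vanilla attention on the sparsified segment (the 1/sqrt(d)
        # scaling is left out to match the simplified equations above).
        Oi = F.softmax(Qi @ Ki.T, dim=-1) @ Vi
        # Scatter the outputs back to their positions; skipped rows stay 0.
        O[idx] = Oi
    return O

# Toy example: 16 tokens, hidden size 8, segment length 8, dilation rate 2.
Q, K, V = (torch.randn(16, 8) for _ in range(3))
out = dilated_attention(Q, K, V, w=8, r=2)
```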

[!note] Why this particular scheme and not a more hierarchical segmentation? The paper offers no real explanation for this design choice.

Interestingly, they mitigate the information loss induced by the sparse representation with multi-head dilated attention: the j-th head shifts its row selection by an offset s_j = j \bmod r, so the heads collectively cover positions that a single sparsified head would skip.
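A minimal sketch of this per-head offset, reusing the single-head function's logic above (illustrative only: a real multi-head layer would use separate projections per head and combine the head outputs with an output projection):

```python
def multihead_dilated_attention(Q, K, V, w, r, n_heads):
    """Each head j shifts its row selection by j % r inside every segment."""
    N, d = Q.shape
    O = torch.zeros(n_heads, N, d)
    for j in range(n_heads):
        offset = j % r  # head-specific offset s_j
        for i in range(N // w):
            idx = torch.arange(i * w + offset, (i + 1) * w, r)
            Qi, Ki, Vi = Q[idx], K[idx], V[idx]
            O[j, idx] = F.softmax(Qi @ Ki.T, dim=-1) @ Vi
    # With n_heads >= r, every position is covered by at least one head;
    # in practice the per-head outputs are combined as in standard MHA.
    return O
```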

Distributed Training

The authors implement a distributed training algorithm that underpins the claimed 1-billion-token sequence length. The input sequence is split along the sequence dimension across multiple GPUs, and each chunk is then projected into queries, keys, and values on its device.

Attention is computed entirely locally when the segment length is smaller than or equal to the portion of the sequence held by a given device. For larger segment lengths, the keys and values are spread across devices and must be collected before the attention computation.

This collection uses an all-gather operation on the key-value pairs; since the sparsified segments have a size independent of the total sequence length, the communication cost stays constant. The outputs from the different devices are then concatenated along the sequence dimension to form the final attention output.
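Below is a heavily simplified sketch of one step of this idea for a single dilation rate r, assuming torch.distributed is already initialised and that the segment spans several devices, so the sparsified keys and values must be gathered. Function and variable names are illustrative, not the paper's reference code, and causality is ignored.

```python
import torch
import torch.distributed as dist

def distributed_dilated_attention(x_local, Wq, Wk, Wv, r):
    """x_local: (N_local, d) chunk of the sequence held by this rank."""
    # 1. Project the local chunk into queries, keys, and values on-device.
    Q, K, V = x_local @ Wq, x_local @ Wk, x_local @ Wv
    # 2. Sparsify locally: keep every r-th row of the chunk.
    idx = torch.arange(0, x_local.shape[0], r)
    Qs, Ks, Vs = Q[idx], K[idx], V[idx]
    # 3. All-gather the sparsified keys/values. Because LongNet grows the
    #    dilation rate with the segment length, the gathered tensors keep a
    #    roughly constant size, so this communication does not grow with N.
    world = dist.get_world_size()
    K_list = [torch.empty_like(Ks) for _ in range(world)]
    V_list = [torch.empty_like(Vs) for _ in range(world)]
    dist.all_gather(K_list, Ks)
    dist.all_gather(V_list, Vs)
    K_full, V_full = torch.cat(K_list), torch.cat(V_list)
    # 4. Attend locally against the gathered keys/values and scatter back;
    #    the per-rank outputs are concatenated along the sequence dimension
    #    to form the final attention output.
    O = torch.zeros_like(Q)
    O[idx] = torch.softmax(Qs @ K_full.T, dim=-1) @ V_full
    return O
```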

Compared to vanilla attention, LongNet scales efficiently even when sequence lengths increase dramatically, thanks to the linear complexity of dilated attention and this distributed algorithm.
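As a rough back-of-the-envelope check of this linear complexity for a single (w, r) configuration (my own arithmetic, following the definitions above): each sparsified segment contains \frac{w}{r} rows, attention over it costs on the order of \left(\frac{w}{r}\right)^2 d operations, and there are \frac{N}{w} segments, so

\text{FLOPs} \approx \frac{N}{w} \cdot \left(\frac{w}{r}\right)^2 d = \frac{Nwd}{r^2}

which grows linearly in N for fixed w and r, compared to the N^2 d of vanilla attention.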

Experiments

LongNet was implemented for language modeling with a Magneto backbone and trained on The Stack dataset. The results were compared against vanilla and sparse Transformers for sequence lengths from 2K to 32K. When evaluated on different sequence lengths, LongNet consistently performed better than the other models, especially as the training sequence length increased.

LongNet also came out ahead when the context length was scaled up during training. While both LongNet and the vanilla Transformer benefited from larger context lengths, LongNet achieved a lower test loss with less computation, demonstrating its efficiency in learning long-range dependencies.

When scaling up the model size, LongNet still followed the usual power law, showing that a dense Transformer is not a prerequisite for scaling language models, and it proved more scalable and efficient than the dense baseline.

The model was also tested for longer context prompting, with the length of the prompt scaled from 2K to 32K. Results showed that as the context window grew, the test loss of LongNet gradually decreased, indicating its superior ability to leverage longer contexts to improve language modeling.